
import pickle 
import numpy as np
from OpenGL.GL import *
from OpenGL.GLU import *
from OpenGL.GLUT import *

# Module-level simulation state.  At module scope a ``global`` statement is a
# no-op, so plain declarations would bind nothing; these placeholders make
# every state name exist from import time.  init() assigns the real values;
# fa is assigned by the start-up code at the bottom of the file.
totalplayground = None  # number of cells in the grid (set by init())
age = None              # declared global by fire()/ffire() but never assigned or read
matrix = None           # forest grid: 1 = tree, -1 = burning, 0 = empty
database = None         # per-run statistics lists keyed by metric name (set by init())
planttree = None        # number of trees grow() plants in the next step
overlapping = None      # set to 0.8 by init(); not read anywhere in the visible code
rate = None             # growth rate used by changeplanttree()
ci = None               # step counter within the current fa run
fa = None               # plot tree-density threshold passed to prevent()
def init():
    """Reset the simulation state for a fresh run.

    Creates an empty 200x200 grid (all cells 0), resets the step counter
    ``ci`` and the statistics lists in ``database``, and restores:

    * ``rate`` -- growth rate used by changeplanttree() (0.5)
    * ``planttree`` -- trees planted by the next grow() call (4000)
    * ``overlapping`` -- 0.8 (not read elsewhere in the visible code)
    * ``totalplayground`` -- number of cells in the grid
    """
    global totalplayground
    global matrix
    global database
    global planttree
    global overlapping
    global rate
    global ci

    ci = 0
    # Plain float: np.float was only an alias of float and was removed in
    # NumPy 1.24.
    rate = 0.5
    overlapping = 0.8
    planttree = 4000
    database = {
        'totalfiredtree': [],
        'totaltree': [],
        'ffire': [],
    }
    matrix = np.zeros((200, 200))
    totalplayground = np.size(matrix)

def changeplanttree():
    """Set ``planttree`` for the next step from a logistic growth rule.

    With T the current number of tree cells (value 1) and N the grid size
    ``totalplayground``, the new planting count is
    ``int(rate * T * (1 - T / N))``, truncated toward zero.
    """
    global planttree
    global rate
    global totalplayground
    # Count over the whole grid; the former double loop used
    # np.size(matrix, 1) for both axes and so silently assumed a square grid.
    treetotal = int(np.count_nonzero(matrix == 1))
    newtree = rate * treetotal * (1.0 - float(treetotal) / float(totalplayground))
    # int() replaces np.int, an alias removed in NumPy 1.24.
    planttree = int(newtree)
    



def totaltree():
    """Count the tree cells (value 1), print the count and record it.

    The count is appended to ``database['totaltree']``.
    """
    # Vectorised count over the whole grid (the former double loop assumed a
    # square grid by using np.size(matrix, 1) for both axes).
    total = int(np.count_nonzero(matrix == 1))
    print("total tree:" + str(total))
    database['totaltree'].append(total)
    



def grow():
    """Plant ``planttree`` new trees on randomly chosen non-tree cells.

    Every cell whose value is not 1 (empty or still burning) is eligible and
    equally likely to be chosen.  If fewer eligible cells remain than
    ``planttree``, all of them are planted (the former rejection loop spun
    forever in that case).
    """
    global planttree
    global matrix
    print('grow')
    # Coordinates of every plantable cell, in row-major order.
    free = np.argwhere(matrix != 1)
    count = min(int(planttree), len(free))
    if count <= 0:
        return
    # Sampling without replacement has the same distribution as planting one
    # uniformly random free cell at a time.  The former randint(-1, size)
    # could return -1, which aliases the last row/column and made it twice
    # as likely to be picked as any other.
    chosen = free[np.random.choice(len(free), count, replace=False)]
    matrix[chosen[:, 0], chosen[:, 1]] = 1

    

def fire():
    """Ignite one random tree and let the fire burn out.

    A tree cell (value 1) is chosen uniformly at random and set to -1.  Fire
    then spreads from every burning cell (-1) to each tree among its eight
    neighbours (diagonals included, clipped at the grid border) until no new
    tree catches.  Burnt cells are left at -1; turnout() clears them.

    Prints "fire turn out" and the number of burnt trees (ignition cell
    included) and appends that number to ``database['totalfiredtree']``.
    An empty forest records 0 instead of raising IndexError.
    """
    global matrix
    from collections import deque

    rows, cols = np.shape(matrix)
    trees = np.argwhere(matrix == 1)
    if len(trees) == 0:
        print("fire turn out")
        print("total fired tree:0")
        database['totalfiredtree'].append(0)
        return
    # randint's upper bound is exclusive, so 0..len-1 is uniform; the former
    # lower bound of -1 aliased the last tree and doubled its chance.
    start = trees[np.random.randint(0, len(trees))]
    matrix[int(start[0])][int(start[1])] = -1

    # Breadth-first spread from every burning cell.  Diagonal neighbours are
    # now also reached from cells in the first and last column (the former
    # sweep skipped all diagonals there).
    steps = ((-1, -1), (-1, 0), (-1, 1), (0, -1),
             (0, 1), (1, -1), (1, 0), (1, 1))
    frontier = deque((int(i), int(j)) for i, j in np.argwhere(matrix == -1))
    total = 0
    while frontier:
        m, n = frontier.popleft()
        for dm, dn in steps:
            i, j = m + dm, n + dn
            if 0 <= i < rows and 0 <= j < cols and matrix[i][j] == 1:
                matrix[i][j] = -1
                total += 1
                frontier.append((i, j))
    print("fire turn out")
    print("total fired tree:" + str(total + 1))
    database['totalfiredtree'].append(total + 1)


def ffire(fx, fy, am, an):
    """Run a controlled burn inside the 10x10 plot whose top-left is (am, an).

    Cell (fx, fy) is set to -1.  Fire spreads from burning cells *inside* the
    plot (rows am..am+9, columns an..an+9) to tree neighbours in all eight
    directions, clipped at the grid border.  A neighbour just outside the
    plot may catch, but it does not spread the fire further.

    The number of burnt trees (ignition included) is appended to
    ``database['ffire']``, then turnout() is called to clear the burning
    cells.
    """
    global matrix
    from collections import deque

    rows, cols = np.shape(matrix)
    matrix[fx][fy] = -1

    def in_plot(m, n):
        # Only cells inside the plot pass the fire on.
        return am <= m < am + 10 and an <= n < an + 10

    steps = ((-1, -1), (-1, 0), (-1, 1), (0, -1),
             (0, 1), (1, -1), (1, 0), (1, 1))
    # Seed with every cell of the plot that is already burning.
    frontier = deque((m, n)
                     for m in range(am, am + 10)
                     for n in range(an, an + 10)
                     if matrix[m][n] == -1)
    total = 0
    while frontier:
        m, n = frontier.popleft()
        if not in_plot(m, n):
            continue
        # Diagonals are checked independently of the border, which the
        # former sweep skipped for cells in the first and last column.
        for dm, dn in steps:
            i, j = m + dm, n + dn
            if 0 <= i < rows and 0 <= j < cols and matrix[i][j] == 1:
                matrix[i][j] = -1
                total += 1
                frontier.append((i, j))
    database['ffire'].append(total + 1)
    turnout()


                    
    

def turnout():
    """Put the fire out: every burning cell (-1) of the grid becomes empty (0).

    Modifies the global grid in place and works for any grid shape (the
    former loop used np.size(matrix, 1) for both axes and so required a
    square grid).
    """
    global matrix
    matrix[matrix == -1] = 0


def caloverlaping(m, n):
    """Return the fraction of tree cells (value 1) in the 10x10 window at (m, n).

    The window covers rows m..m+9 and columns n..n+9 of the global grid; the
    tree count is divided by 100.
    """
    global matrix
    trees = sum(
        1
        for row in range(m, m + 10)
        for col in range(n, n + 10)
        if matrix[row][col] == 1
    )
    return trees / 100.0

def prevent(fa):
    """Start controlled burns in dense 10x10 plots of the grid.

    The grid is scanned plot by plot (10x10, row-major).  In every plot whose
    tree fraction parttoal(m, n) is at least ``fa``, the first tree cell (in
    row-major order) whose forecast() value f satisfies ``1 - f >= 0.1`` is
    ignited with ffire().  At most one burn is started per plot.  If any plot
    was burnt, a 0 is appended to ``database['ffire']`` after the per-burn
    counts (presumably a separator between steps -- TODO confirm with the
    analysis code).

    :param fa: minimum plot tree fraction that triggers a controlled burn.
    """
    global matrix

    print("prevent start")
    rows = np.size(matrix, 0)
    cols = np.size(matrix, 1)
    hasfire = 0
    for m in range(0, rows, 10):
        for n in range(0, cols, 10):
            # parttoal() is evaluated first so sparse plots skip the costly
            # forecast search, as before.
            if parttoal(m, n) >= fa and _burn_first_candidate(m, n):
                hasfire = hasfire + 1
    if hasfire != 0:
        database['ffire'].append(0)


def _burn_first_candidate(m, n):
    """Ignite the first suitable tree of plot (m, n); return True if one burnt."""
    for fx in range(10):
        for fy in range(10):
            x, y = m + fx, n + fy
            # Previously matrix[m+fx][n+fx] was tested here (column offset fx
            # instead of fy), so the wrong cell decided whether to forecast.
            if matrix[x][y] != 1:
                continue
            if (1 - forecast(x, y, m, n)) >= 10 / 100.0:
                ffire(x, y, m, n)
                return True
    return False

def parttoal(am, an):
    """Return the fraction of tree cells (value 1) in the 10x10 plot at (am, an).

    The plot covers rows am..am+9 and columns an..an+9; the tree count is
    divided by 100.
    """
    global matrix
    total = 0
    for a in range(am, am + 10):
        # The column range starts at ``an``; it used to start at ``am``, which
        # measured the plot on the diagonal instead of the requested one.
        for b in range(an, an + 10):
            if matrix[a][b] == 1:
                total = total + 1
    return total / 100.0


def forecast(x, y, am, an):
    """Estimate how much of plot (am, an) a burn started at (x, y) would take.

    Simulates the spread rule of ffire() without modifying ``matrix``.  Fire
    passes from burning cells inside the 10x10 plot at (am, an) to tree
    neighbours in all eight directions, clipped at the grid border.
    Neighbours outside the plot may catch but do not spread further.

    Returns the number of trees that would catch, *excluding* (x, y) itself,
    divided by 100.
    """
    global matrix
    from collections import deque

    rows, cols = np.shape(matrix)
    steps = ((-1, -1), (-1, 0), (-1, 1), (0, -1),
             (0, 1), (1, -1), (1, 0), (1, 1))
    start = (int(x), int(y))
    # A visited set replaces the former fixed 200x200 scratch grid, so the
    # forecast follows the real grid's shape.
    burning = set([start])
    frontier = deque([start])
    total = 0
    while frontier:
        m, n = frontier.popleft()
        if not (am <= m < am + 10 and an <= n < an + 10):
            continue
        # Diagonals are checked independently of the border; the former sweep
        # skipped them for cells in the first and last column.
        for dm, dn in steps:
            i, j = m + dm, n + dn
            if (0 <= i < rows and 0 <= j < cols
                    and matrix[i][j] == 1 and (i, j) not in burning):
                burning.add((i, j))
                total += 1
                frontier.append((i, j))
    return total / 100.0

        
def _draw_grid(show_fire=False, clear=True):
    """Draw the global grid with OpenGL: trees green, burning cells red.

    Each cell becomes a quad in normalised device coordinates (-1..1), with
    the row index along x and the column index along y.  Burning cells (-1)
    are drawn only when ``show_fire`` is true.  When ``clear`` is true the
    colour and depth buffers are cleared first.  Ends with glFlush().
    """
    size = np.size(matrix, 1)
    cell = 2.0 / size
    if clear:
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
    glColor3f(0.0, 1.0, 0)
    # One glBegin/glEnd for the whole frame: GL_QUADS takes every group of
    # four vertices as a quad, and glColor may change between vertices.
    glBegin(GL_QUADS)
    for m in range(size):
        for n in range(size):
            value = matrix[m][n]
            if value == 1:
                glColor3f(0.0, 1.0, 0)
            elif show_fire and value == -1:
                glColor3f(1.0, 0.0, 0)
            else:
                continue
            a = -1.0 + cell * m
            b = -1.0 + cell * n
            c = a + cell
            d = b + cell
            glVertex2f(a, d)
            glVertex2f(c, d)
            glVertex2f(c, b)
            glVertex2f(a, b)
    glEnd()
    glFlush()


def doit():
    """GLUT display/idle callback: advance the simulation by one step.

    One step grows new trees and records the tree count, then draws.  It
    runs the prevent(fa) controlled burns and draws again.  It starts a
    wildfire with fire() and draws trees plus burning cells.  Finally it
    clears the fire and computes the next planting count.

    After 50 steps the accumulated results are pickled to
    ``database<fa>.pickle``, ``fa`` is raised by 0.05 and the simulation is
    reset with init().  The program exits once ``fa`` exceeds 0.5.
    """
    global ci
    global fa
    global database
    global result
    import sys

    print("=========" + str(ci) + "=========")
    grow()
    totaltree()
    _draw_grid()
    prevent(fa)
    _draw_grid()
    fire()
    # No buffer clear here: the fire frame is drawn over the previous one.
    _draw_grid(show_fire=True, clear=False)
    turnout()
    changeplanttree()
    ci = ci + 1
    if ci == 50:
        result.append([fa, database])
        with open('database' + str(fa) + '.pickle', 'wb') as f:
            pickle.dump(result, f)
        result = []
        # Round so repeated "+ 0.05" cannot accumulate binary rounding error
        # (e.g. 0.1 + 0.05 == 0.15000000000000002), which leaked into file
        # names and the "> 0.5" stop test.
        fa = round(fa + 0.05, 2)
        print("================================fa :" + str(fa) + "================================")
        if fa > 0.5:
            # sys.exit is always available; the exit() builtin comes from
            # the site module and is missing under "python -S".
            sys.exit()
        init()
        ci = 0



# Start-up: first plot-density threshold, a fresh simulation state and an
# empty result buffer; then hand control to GLUT, which calls doit() both to
# redraw and whenever it is idle.  (init() used to be called twice here; the
# second call only rebuilt an identical state.)
fa = 0.05
init()
result = []
glutInit()
glutInitDisplayMode(GLUT_SINGLE | GLUT_RGBA)
glutInitWindowSize(600, 600)
glutCreateWindow("First")
glutDisplayFunc(doit)
glutIdleFunc(doit)
glutMainLoop()




